
#include <cstdlib>
#include <fstream>
#include <sstream>
#include <iomanip>
#include <vector>
#include <cmath>
#include <limits>
#include <iostream>
#include <ros/ros.h>
#include <ros/package.h>

#include <opencv2/dnn.hpp>
#include <opencv2/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/imgcodecs/imgcodecs.hpp>


#include <image_transport/image_transport.h>
#include <cv_bridge/cv_bridge.h>

#include <visualization_msgs/Marker.h>
#include <visualization_msgs/MarkerArray.h>

using namespace std;

// object detection

// YOLOv3 model files: COCO class names, Darknet network config, and trained weights.
// NOTE(review): absolute home-directory paths — these break on any other machine.
// Consider loading them from ROS parameters or via ros::package::getPath()
// (ros/package.h is already included above).
static string yoloClassesFile = "/home/zh/lidarimg_ws/yolo/coco.names";
static string yoloModelConfiguration = "/home/zh/lidarimg_ws/yolo/yolov3.cfg";
static string yoloModelWeights = "/home/zh/lidarimg_ws/yolo/yolov3.weights";
// Alternative models (custom VOC weights / yolov3-tiny) kept for experiments:
// static string yoloModelWeights = "/home/zh/lidarimg_ws/yolo/yolov3-voc_final.weights";
// static string yoloModelConfiguration = "/home/zh/lidarimg_ws/yolo/yolov3-tiny.cfg";
// static string yoloModelWeights = "/home/zh/lidarimg_ws/yolo/yolov3-tiny.weights";



// One 2D object detection produced by the YOLO network for a single frame.
struct BoundingBox {
    int boxID; // zero-based unique identifier within the current frame (assigned in detectObjects)
    int trackID; // cross-frame track identifier — not set anywhere in this file; presumably filled in by a tracker elsewhere
    
    cv::Rect roi; // 2D bounding rectangle in image pixel coordinates
    int classID; // index into the class-name list loaded from classesFile (COCO ids when coco.names is used)
    double confidence; // best class score reported by the network for this box
    std::vector<cv::KeyPoint> keypoints; // keypoints enclosed by 2D roi
    std::vector<cv::DMatch> kptMatches; // keypoint matches enclosed by 2D roi
};

// Run YOLO object detection on `img` and append car detections to `bBoxes`.
//
// Parameters:
//   img                - input BGR image (also used for the optional visualization)
//   bBoxes             - output: detections of class 2 ("car" in COCO) are appended
//   confThreshold      - minimum class score for a candidate box to be kept
//   nmsThreshold       - IoU threshold for non-maximum suppression
//   classesFile        - text file with one class name per line (for labels)
//   modelConfiguration - Darknet .cfg file
//   modelWeights       - Darknet .weights file
//   bVis               - when true, draw the detections in an OpenCV window
void detectObjects(cv::Mat& img, std::vector<BoundingBox>& bBoxes, float confThreshold, float nmsThreshold, 
                std::string classesFile, std::string modelConfiguration, std::string modelWeights, bool bVis)
{
    // Load class names from file (one name per line).
    vector<string> classes;
    ifstream ifs(classesFile.c_str());
    string line;
    while (getline(ifs, line)) classes.push_back(line);
    
    // Load the Darknet network.
    // NOTE(review): the network is re-loaded on every call; for a per-frame
    // callback this is very expensive — consider caching the loaded net.
    cv::dnn::Net net = cv::dnn::readNetFromDarknet(modelConfiguration, modelWeights);
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV);
    
    // Build the 4D input blob: scale pixels to [0,1], resize to 320x320,
    // swap BGR->RGB (swapRB), no mean subtraction, no center crop.
    double scalefactor = 1/255.0;
    cv::Size size = cv::Size(320, 320);
    cv::Scalar mean = cv::Scalar(0,0,0);
    bool swapRB = true;
    bool crop = false;
    cv::Mat blob = cv::dnn::blobFromImage(img, scalefactor, size, mean, swapRB, crop);
    
    // Collect the names of the unconnected (output) layers.
    vector<cv::String> names;
    vector<int> outLayers = net.getUnconnectedOutLayers();
    vector<cv::String> layersNames = net.getLayerNames();
    names.resize(outLayers.size());
    for (size_t i = 0; i < outLayers.size(); ++i){
        names[i] = layersNames[outLayers[i] - 1]; // layer ids are 1-based
    }

    // Forward pass; report wall-clock inference time.
    vector<cv::Mat> netOutput;
    double t = (double)cv::getTickCount();
    net.setInput(blob);
    net.forward(netOutput, names);
    t = ((double)cv::getTickCount() - t) / cv::getTickFrequency();
    cout <<  1000 * t / 1.0 << " ms" << endl;

    // Each output row is [cx, cy, w, h, objectness, class scores...] with
    // coordinates normalized to [0,1]. Keep candidates above confThreshold.
    vector<int> classIds; vector<float> confidences; vector<cv::Rect> boxes;
    for (size_t i = 0; i < netOutput.size(); ++i)
    {
        float* data = (float*)netOutput[i].data;
        for (int j = 0; j < netOutput[i].rows; ++j, data += netOutput[i].cols)
        {
            cv::Mat scores = netOutput[i].row(j).colRange(5, netOutput[i].cols);
            cv::Point classId;
            double confidence;
            
            // Get the value and location of the maximum class score.
            cv::minMaxLoc(scores, 0, &confidence, 0, &classId);
            if (confidence > confThreshold)
            {
                cv::Rect box; int cx, cy;
                cx = (int)(data[0] * img.cols);          // box center x in pixels
                cy = (int)(data[1] * img.rows);          // box center y in pixels
                box.width = (int)(data[2] * img.cols);
                box.height = (int)(data[3] * img.rows);
                box.x = cx - box.width/2;  // left
                box.y = cy - box.height/2; // top
                
                boxes.push_back(box);
                classIds.push_back(classId.x);
                confidences.push_back((float)confidence);
            }
        }
    }
    
    // Perform non-maximum suppression, then keep only class 2 ("car" in COCO).
    vector<int> indices;
    cv::dnn::NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);
    for(auto it=indices.begin(); it!=indices.end(); ++it) {
        
        BoundingBox bBox;
        bBox.roi = boxes[*it];
        bBox.classID = classIds[*it];
        bBox.confidence = confidences[*it];
        bBox.boxID = (int)bBoxes.size(); // zero-based unique identifier for this bounding box
        // BUG FIX: the original used `=` (assignment) instead of `==`, which
        // overwrote every detection's classID with 2 and made the condition
        // always true — every class was kept and mislabeled as class 2.
        if(bBox.classID == 2){
            bBoxes.push_back(bBox);
        }
    }
    
    // Optionally draw the surviving detections.
    if(bVis) {
        
        cv::Mat visImg = img.clone();
        for(auto it=bBoxes.begin(); it!=bBoxes.end(); ++it) {
            
            // Draw rectangle displaying the bounding box.
            int top = (*it).roi.y;
            int left = (*it).roi.x;
            int width = (*it).roi.width;
            int height = (*it).roi.height;
            cv::rectangle(visImg, cv::Point(left, top), cv::Point(left+width, top+height), cv::Scalar(0, 255, 0), 2);
            
            // Build "class:confidence" label; guard against a missing or
            // short classes file so we never index out of range.
            string label = cv::format("%.2f", (*it).confidence);
            if ((*it).classID >= 0 && (size_t)(*it).classID < classes.size()) {
                label = classes[(*it).classID] + ":" + label;
            }
        
            // Display the label on a filled background at the top of the box.
            int baseLine;
            cv::Size labelSize = getTextSize(label, cv::FONT_ITALIC, 0.5, 1, &baseLine);
            top = max(top, labelSize.height); // keep the label inside the image
            rectangle(visImg, cv::Point(left, top - round(1.5*labelSize.height)), cv::Point(left + round(1.5*labelSize.width), top + baseLine), cv::Scalar(255, 255, 255), cv::FILLED);
            cv::putText(visImg, label, cv::Point(left, top), cv::FONT_ITALIC, 0.75, cv::Scalar(0,0,0), 1);
        }
        
        string windowName = "车辆检测";
        cv::namedWindow( windowName, 2 );
        cv::imshow( windowName, visImg );
        cv::waitKey(10);
    }
}


// Per-frame image callback: converts the incoming ROS image to an OpenCV
// BGR matrix, shows it, and runs YOLO vehicle detection on it.
void imgCallback(const sensor_msgs::ImageConstPtr& imgmsg)
{
    // Convert the ROS message to an OpenCV image (BGR8, deep copy).
    cv_bridge::CvImagePtr bridged = cv_bridge::toCvCopy(imgmsg, sensor_msgs::image_encodings::BGR8);
    cv::Mat frame = bridged->image;

    // Detection parameters: minimum class score, NMS IoU threshold,
    // and whether to draw the results.
    const float confThreshold = 0.2;
    const float nmsThreshold = 0.4;
    const bool bVis = true;

    // Show the raw camera frame.
    string rawWindow = "图像显示";
    cv::namedWindow(rawWindow, 1);
    cv::imshow(rawWindow, frame);
    cv::waitKey(10);

    // Run the detector; results accumulate in `detections` (unused here,
    // matching the original behavior).
    std::vector<BoundingBox> detections;
    detectObjects(frame, detections, confThreshold, nmsThreshold,
                  yoloClassesFile, yoloModelConfiguration, yoloModelWeights, bVis);
}

// Entry point: a ROS node that feeds camera frames to the YOLO vehicle detector.
int main(int argc, char** argv)
{
    ros::init(argc, argv, "image_detector");
    ros::NodeHandle nh;

    // Subscribe to the raw camera stream; every frame is processed by imgCallback.
    image_transport::ImageTransport transport(nh);
    image_transport::Subscriber imageSub = transport.subscribe("/image_raw", 10, imgCallback);

    // Hand control to ROS until shutdown.
    ros::spin();
    return 0;
}

// For testing only
/*
void pubImage(ros::NodeHandle nh){
  image_transport::ImageTransport it(nh);
  image_transport::Publisher pub = it.advertise("camera/image", 1);
 
  cv::Mat image = cv::imread("/home/leo/Pictures/1.jpeg", CV_LOAD_IMAGE_COLOR);
  if(image.empty())
  {
    printf("open error\n");
  }
  sensor_msgs::ImagePtr msg = cv_bridge::CvImage(std_msgs::Header(), "bgr8", image).toImageMsg();//图像格式转换
 
  ros::Rate loop_rate(5);//每秒5帧
  while (nh.ok()) 
  {
    pub.publish(msg);
    ros::spinOnce();
    loop_rate.sleep();
  }
}
*/
